---
title: "Fast Cold Starts with Cached Weights"
description: "Deploy a language model, with the model weights cached at build time"
---


<Card
  title="View on GitHub"
  icon="github" href="https://github.com/basetenlabs/truss-examples/tree/main/06-high-performance-cached-weights">
</Card>

This example walks through a Truss that serves an LLM and _caches_ its weights
at build time. Loading weights is often the most time-consuming part of starting
a model. Caching them at build time bakes the weights into the Truss image, so they
are available _immediately_ when a model replica starts, which makes **cold starts**
_significantly faster_.

# Implementing the `Model` class

Weight caching requires no changes to the `Model` class. The implementation
below is the same as it would be without caching.

```python model/model.py
from typing import Dict, List

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"


def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"


class Model:
    def __init__(self, **kwargs) -> None:
        self.model = None
        self.tokenizer = None

    def load(self):
        self.model = LlamaForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)

    def predict(self, request: Dict) -> Dict[str, str]:
        prompt = request.pop("prompt")
        input_ids = self.tokenizer(
            format_prompt(prompt), return_tensors="pt"
        ).input_ids.cuda()

        outputs = self.model.generate(
            inputs=input_ids, do_sample=True, num_beams=1, max_new_tokens=100
        )
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        return {"response": response}
```
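To see what the model actually receives, here is a minimal standalone sketch of the prompt formatting. It copies the constants and `format_prompt` from the listing above and needs neither `torch` nor the model weights; the example question is only an illustration.

```python
# Standalone copy of the prompt-formatting logic from model/model.py,
# so it can be run without torch, transformers, or a GPU.
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    # Wrap the system prompt in <<SYS>> markers and the whole turn in [INST] markers.
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"


# Illustrative question; any string works here.
formatted = format_prompt("What is a large language model?")

# repr() makes the embedded newlines visible.
print(repr(formatted))
```

The formatted string is what `predict` tokenizes before calling `generate`.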

# Setting up the config.yaml

Weight caching is enabled in `config.yaml`. The base settings are the same
as for any other Truss:

```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.34.0
- sentencepiece==0.1.99
- protobuf==4.24.4
```

# Configuring the model_cache

To cache model weights, set the `model_cache` key.
The `repo_id` field specifies a Hugging Face repo to download and cache
at build time, and the `ignore_patterns` field specifies files to skip.
Once a repo is listed under `model_cache`, it no longer has to be
downloaded at runtime.

Check out the [guide](https://truss.baseten.co/guides/model-cache) for more info.

```yaml config.yaml
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
  ignore_patterns:
  - "*.bin"

```
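As a rough illustration of what `ignore_patterns` does, the sketch below filters a file listing with shell-style globs. It assumes the patterns behave like Python's `fnmatch` globs, and the file names are hypothetical; the real repo contents and Truss's exact matching logic may differ.

```python
from fnmatch import fnmatchcase
from typing import List

# Hypothetical file listing, for illustration only.
repo_files = [
    "config.json",
    "tokenizer.model",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
]

# Same pattern as in the config.yaml snippet above.
ignore_patterns = ["*.bin"]


def is_ignored(filename: str, patterns: List[str]) -> bool:
    # A file is skipped if it matches any ignore pattern (assumed glob semantics).
    return any(fnmatchcase(filename, pattern) for pattern in patterns)


cached = [f for f in repo_files if not is_ignored(f, ignore_patterns)]
skipped = [f for f in repo_files if is_ignored(f, ignore_patterns)]

print("cached at build time:", cached)
print("skipped:", skipped)
```

Under these assumptions, only the `.bin` files are left out of the image, while the `.safetensors` files, config, and tokenizer files are cached.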
The remaining config options are the same as what you would
configure for the model without weight caching.

```yaml config.yaml
resources:
  cpu: "4"
  memory: 30Gi
  use_gpu: True
  accelerator: A10G
secrets: {}
```

# Deploy the model

Deploy the model like you would other Trusses, with:
```bash
$ truss push
```
<Note>
The build step will take longer than with the normal
Llama Truss, since the model weights are now bundled during the build.
In exchange, the deploy step and scale-ups will be much faster.
</Note>

You can then invoke the model with:
```bash
$ truss predict -d '{"prompt": "What is a large language model?"}'
```
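The request body must carry a `prompt` key, since that is what `predict` pops from the request. As a small sketch, the payload and a safely quoted command line can be built in Python; the question text is illustrative, and the script only prints the command rather than running it.

```python
import json
import shlex

# The key must be "prompt": Model.predict calls request.pop("prompt").
payload = {"prompt": "What is a large language model?"}

# Serialize to JSON and quote it so the shell passes it as a single argument.
body = json.dumps(payload)
command = f"truss predict -d {shlex.quote(body)}"

print(command)
```

Building the body with `json.dumps` avoids hand-written quoting mistakes when the prompt itself contains quotes.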

<RequestExample>
```python model/model.py
from typing import Dict, List

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"


def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"


class Model:
    def __init__(self, **kwargs) -> None:
        self.model = None
        self.tokenizer = None

    def load(self):
        self.model = LlamaForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)

    def predict(self, request: Dict) -> Dict[str, str]:
        prompt = request.pop("prompt")
        input_ids = self.tokenizer(
            format_prompt(prompt), return_tensors="pt"
        ).input_ids.cuda()

        outputs = self.model.generate(
            inputs=input_ids, do_sample=True, num_beams=1, max_new_tokens=100
        )
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        return {"response": response}
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.34.0
- sentencepiece==0.1.99
- protobuf==4.24.4
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
  ignore_patterns:
  - "*.bin"

resources:
  cpu: "4"
  memory: 30Gi
  use_gpu: True
  accelerator: A10G
secrets: {}
```
</RequestExample>
